import os
import autograd.numpy as np
import matplotlib.pyplot as plt
from autograd import elementwise_grad

# Seed the global NumPy RNG so the uniformly drawn starting points are reproducible.
np.random.seed(seed=111)

def H(X):
    """Evaluate the energy H row-wise.

    X is indexed as X[:, 0] .. X[:, 5], so it must be 2-D with at least six
    columns; the result has one value per row. p_true uses exp(-H(X)).

    The first three coordinates enter through a coupled quadratic (weight 2)
    minus log(x_i**2 + 0.02) terms on x_1 and x_2. The last three coordinates
    enter through a separate coupled quadratic (weight 0.5).
    """
    a, b, c = X[:, 0], X[:, 1], X[:, 2]
    d, e, f = X[:, 3], X[:, 4], X[:, 5]

    quad_abc = a**2 + b**2 + c**2 + 0.5 * (a * b + a * c + b * c)
    quad_def = d**2 + e**2 + f**2 + 0.2 * (d * e + d * f + e * f)

    # Same left-to-right evaluation order as the single-expression form.
    return 2 * quad_abc - np.log(a**2 + 0.02) - np.log(b**2 + 0.02) + 0.5 * quad_def

def p_true(x):
    """Return exp(-H(x)) divided by a fixed constant, one value per row of x.

    The divisor 3.24729309964116 is hard-coded. It is presumably the
    normalizing integral of exp(-H) so that the result is a density; that is
    not verified here.
    """
    energy = H(x)
    return np.exp(-energy) / 3.24729309964116

# Problem dimension and the per-coordinate interval [xL, xR] used to draw
# starting points (and to frame the plot).
dim = 6
xL = -2
xR = 2

# Output directory for the generated point set.
path = './data/'
# exist_ok=True replaces the separate os.path.exists check, which could race
# with another process creating the directory in between.
os.makedirs(path, exist_ok=True)

# Build the evaluation point set saved as x_error.npy (presumably used later for
# MAE / MAPE error metrics -- confirm against the consumer). Uniform samples are
# pushed toward higher density until every point satisfies p_true > threshold.
g = elementwise_grad(H)  # dH/dx for each row (rows of H are independent)

# Initialize x uniformly in [xL, xR) per coordinate
n_points = 10000
x = (xR - xL) * np.random.rand(n_points, dim) + xL
lr = 1e-3
threshold = 1e-5
max_iterations = 1000

# Gradient descent on H, i.e. ascent on p_true, which is proportional to exp(-H).
for it in range(max_iterations):
    g_x = g(x)
    x -= lr * g_x
    p = p_true(x)
    # Vectorized reduction, computed once; builtin min() would loop in Python.
    p_min = p.min()
    print("It: {}, min p: {:.2e}".format(it, p_min))

    if np.all(p > threshold):
        print("Minimum p-value for data: {:.2e}".format(p_min))
        np.save(os.path.join(path, 'x_error.npy'), x)
        break
else:
    # Runs only when the loop exhausted max_iterations without breaking;
    # without this the script would silently skip saving the data.
    print("Warning: not all points reached p > {:.2e} within {} iterations; "
          "x_error.npy was not saved".format(threshold, max_iterations))

# Plot the (x_1, x_2) projection of the final points. Axes follow the sampling
# interval [xL, xR] instead of hard-coded limits, so they stay in sync with it.
# Note: this runs even if the loop above did not save x_error.npy.
plt.figure(figsize=(3, 3))
plt.scatter(x[:, 0], x[:, 1], s=6)
plt.title('6D Multi-modal: $(x_1, x_2)$')

plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
ticks = np.linspace(xL, xR, 5)
plt.xticks(ticks)
plt.yticks(ticks)
plt.xlim(xL, xR)
plt.ylim(xL, xR)

plt.show()